{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "19b9f9be-bff1-47dc-8be0-576ef6a557ff",
   "metadata": {
    "editable": true,
    "slideshow": {
     "slide_type": ""
    },
    "tags": []
   },
   "source": [
    "# Generating News Headlines using GPT2\n",
    "## GPT2\n",
    "The GPT-2 model was released in 2019 as part of the work titled “Language Models are Unsupervised Multitask Learners”. The largest GPT-2 variant is a 1.5B-parameter transformer-based model that performs remarkably well on a variety of NLP tasks. The most striking aspect of this work is that the authors show how a model trained in an unsupervised fashion (language modeling) achieves state-of-the-art performance in a zero-shot setting.\n",
    "\n",
    "## HuggingFace Transformers\n",
    "Hugging Face Transformers is one of the most popular Python packages for working with transformer-based NLP models. It provides a high-level API to easily load, fine-tune, and re-train models such as GPT-2, BERT, and T5.\n",
    "\n",
    "## Fake Headlines\n",
    "The ABC News dataset is a set of over a million headlines collected over a period of 17 years, available from the [Harvard Dataverse](https://dataverse.harvard.edu/dataset.xhtml?persistentId=doi:10.7910/DVN/SYBGZL). We will use this dataset to fine-tune the GPT-2 model. Once fine-tuned, we will use the model to generate some fake headlines.\n",
    "\n",
    "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/PacktPublishing/Generative-AI-with-Python-and-PyTorch-Second-Edition/blob/master/ch_04/03_gpt2_headlines_generator.ipynb)"
   ]
  },
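  {
   "cell_type": "markdown",
   "id": "sketch-zero-shot-note",
   "metadata": {},
   "source": [
    "Before fine-tuning, it can help to see how little code the `transformers` high-level API needs to load a pretrained model and sample from it. The next cell is a minimal sketch and not part of the training pipeline: it assumes the small public `gpt2` checkpoint and a made-up prompt (`australia`), and because sampling is random the printed text will differ between runs."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "sketch-zero-shot-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import pipeline\n",
    "\n",
    "# Assumption: the small public 'gpt2' checkpoint, used purely for illustration\n",
    "generator = pipeline('text-generation', model='gpt2')\n",
    "\n",
    "# Hypothetical prompt; sampling is random, so the continuations differ between runs\n",
    "samples = generator('australia', max_new_tokens=12, do_sample=True, num_return_sequences=3)\n",
    "for sample in samples:\n",
    "    print(sample['generated_text'])"
   ]
  },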
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "d4f3d782-174b-4c55-8d11-8b9a52c34fe9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# !pip3 install scikit-learn==1.5.1\n",
    "# !pip3 install transformers==4.42.4"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "a457e201-f112-4bbe-9ebb-04a39961ce73",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "from sklearn.model_selection import train_test_split\n",
    "\n",
    "import torch\n",
    "from transformers import pipeline\n",
    "from transformers import AutoTokenizer\n",
    "from transformers import TextDataset,DataCollatorForLanguageModeling\n",
    "from transformers import Trainer, TrainingArguments,AutoModelForCausalLM"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fffef4bc-e884-4509-9636-eedb039e8961",
   "metadata": {},
   "source": [
    "### Prepare Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "791b7720-872d-4da9-84a2-d7000e7946fb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# download from https://dataverse.harvard.edu/dataset.xhtml?persistentId=doi:10.7910/DVN/SYBGZL\n",
    "# !unzip abcnews.zip"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "15c18110-cd6a-4695-97fe-17f279ec87e3",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(1241692, 2)"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "news = pd.read_csv('abcnews-date-text-sample.csv')\n",
    "news.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "ce909286-1c48-45bf-8227-51dccaabca6f",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>publish_date</th>\n",
       "      <th>headline_text</th>\n",
       "      <th>line_length</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>20030219</td>\n",
       "      <td>aba decides against community broadcasting lic...</td>\n",
       "      <td>50</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>20030219</td>\n",
       "      <td>act fire witnesses must be aware of defamation</td>\n",
       "      <td>46</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>20030219</td>\n",
       "      <td>a g calls for infrastructure protection summit</td>\n",
       "      <td>46</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>20030219</td>\n",
       "      <td>air nz staff in aust strike for pay rise</td>\n",
       "      <td>40</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>20030219</td>\n",
       "      <td>air nz strike to affect australian travellers</td>\n",
       "      <td>45</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   publish_date                                      headline_text  \\\n",
       "0      20030219  aba decides against community broadcasting lic...   \n",
       "1      20030219     act fire witnesses must be aware of defamation   \n",
       "2      20030219     a g calls for infrastructure protection summit   \n",
       "3      20030219           air nz staff in aust strike for pay rise   \n",
       "4      20030219      air nz strike to affect australian travellers   \n",
       "\n",
       "   line_length  \n",
       "0           50  \n",
       "1           46  \n",
       "2           46  \n",
       "3           40  \n",
       "4           45  "
      ]
     },
     "execution_count": 25,
     "metadata": {},
     "output_type": "execute_result"
    }
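   ],
   "source": [
    "# headline length in characters, used by the histogram in the next cell\n",
    "news['line_length'] = news['headline_text'].str.len()\n",
    "news.head()"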
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "id": "9e9205e3-0921-45c1-af1e-68889ee2342d",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<Axes: >"
      ]
     },
     "execution_count": 28,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAkIAAAGdCAYAAAD+JxxnAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/TGe4hAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAuAElEQVR4nO3df1iUdb7/8RcgDGCCogvIEZWtTv7WwiT6dSyR0bg6ubltPzxmZnblBW3IOaa2SqjtstH6s0hOW2p7rexq17VZaQeZMDVX1ETZ1NKt1s06NtjJH5OYw8jM94+9uL+OIEgLjM7n+bguLp37fnPP+/0Rxhf3zD2E+Hw+nwAAAAwUGugGAAAAAoUgBAAAjEUQAgAAxiIIAQAAYxGEAACAsQhCAADAWAQhAABgLIIQAAAwVqdAN3A583q9Onr0qLp06aKQkJBAtwMAAC6Bz+fTd999p6SkJIWGNn/OhyDUjKNHjyo5OTnQbQAAgB/gyy+/VK9evZqtIQg1o0uXLpL+sZAxMTEB7qZteTwelZeXKzMzU+Hh4YFuJyBMXwPmN3t+iTUwfX4peNfA5XIpOTnZ+n+8OQShZjQ8HRYTExOUQSg6OloxMTFB9cXfGqavAfObPb/EGpg+vxT8a3ApL2vhxdIAAMBYBCEAAGAsghAAADAWQQgAABiLIAQAAIxFEAIAAMYiCAEAAGMRhAAAgLEIQgAAwFgEIQAAYCyCEAAAMBZBCAAAGIsgBAAAjEUQAgAAxuoU6AYAIJAGFWyUuz4k0G1csr//OivQLQBBhTNCAADAWAQhAABgLIIQAAAwFkEIAAAYiyAEAACMRRACAADGIggBAABjEYQAAICxCEIAAMBYvLM0cBnqO2tDu9+HLcynohFt987KvOMxgCsRQQhB72Khoq2DAADgysNTYwAAwFgEIQAAYCyCEAAAMBZBCAAAGIsgBAAAjEUQAgAAxiIIAQAAYxGEAACAsQhCAADAWAQhAABgLIIQAAAwFkEIAAAYiyAEAACMRRACAADGIggBAABjEYQAAICxOgW6AQDBoe+sDYFuoVVsYT4VjQh0FwACjTNCAADAWAQhAABgLIIQAAAwFkEIAAAYiyAEAACMRRACAADGIggBAABjEYQAAICxCEIAAMBYBCEAAGCsVgWhwsJC3XjjjerSpYvi4+M1btw4HTp0yK9m5MiRCgkJ8ft44okn/GqOHDmirKwsRUdHKz4+XjNmzNC5c+f8ajZv3qwbbrhBNptN11xzjVatWtWon+LiYvXt21eRkZFKS0vTrl27/PafPXtW2dnZ6t69u6666iqNHz9eNTU1rRkZAAAEsVYFoS1btig7O1s7duyQw+GQx+NRZmamamtr/eqmTp2qr7/+2vooKiqy9tXX1ysrK0t1dXXavn27Xn/9da1atUr5+flWzeHDh5WVlaU77rhD1dXVys3N1WOPPaaNGzdaNWvWrFFeXp6effZZ7dmzR0OHDpXdbtexY8esmunTp+udd97RG2+8oS1btujo0aO69957W71IAAAgOLXql66WlZX53V61apXi4+NVVVWl22+/3doeHR2txMTEJo9RXl6ujz/+WO+9954SEhI0bNgwLViwQDNnzlRBQYEiIiJUUlKilJQULVy4UJLUv39/bdu2TYsXL5bdbpckLVq0SFOnTtXkyZMlSSUlJdqwYYNWrFihWbNm6dSpU3rttddUWlqqO++8U5K0cuVK9e/fXzt27NBNN93UmtEBAEAQ+qd++/ypU6ckSXFxcX7bV69erd///vdKTEzU3Xffrblz5yo6OlqSVFlZqcGDByshIcGqt9vtmjZtmg4cOKDrr79elZWVysjI8Dum3W5Xbm6uJKmurk5VVVWaPXu2tT80NFQZGRmqrKyUJFVVVcnj8fgdp1+/furdu7cqKyubDEJut1tut9u67XK5JEkej0cej6fV63M5a5gn2OZqii3M1/T2UJ/fn6Zh/itz/rb8
njXpcaApps8vBe8atGaeHxyEvF6vcnNzdcstt2jQoEHW9oceekh9+vRRUlKSPvroI82cOVOHDh3Sn/70J0mS0+n0C0GSrNtOp7PZGpfLpe+//14nTpxQfX19kzUHDx60jhEREaGuXbs2qmm4nwsVFhZq3rx5jbaXl5dbQS7YOByOQLfQ7opGNL9/wXBvxzRymWL+K2v+d999t82PacLjQHNMn18KvjU4c+bMJdf+4CCUnZ2t/fv3a9u2bX7bH3/8cevvgwcPVs+ePTVq1Ch9/vnnuvrqq3/o3XWI2bNnKy8vz7rtcrmUnJyszMxMxcTEBLCztufxeORwODR69GiFh4cHup12NahgY5PbbaE+LRju1dzdoXJ7Qzq4q8Bj/itz/v0F9jY7lkmPA00xfX4peNeg4RmdS/GDglBOTo7Wr1+vrVu3qlevXs3WpqWlSZI+++wzXX311UpMTGx0dVfDlVwNrytKTExsdHVXTU2NYmJiFBUVpbCwMIWFhTVZc/4x6urqdPLkSb+zQufXXMhms8lmszXaHh4eHlRfIOcL5tkauOub/0/O7Q1psSaYMf+VNX97fL+a8DjQHNPnl4JvDVozS6uuGvP5fMrJydGbb76pTZs2KSUlpcXPqa6uliT17NlTkpSenq59+/b5Xd3lcDgUExOjAQMGWDUVFRV+x3E4HEpPT5ckRUREKDU11a/G6/WqoqLCqklNTVV4eLhfzaFDh3TkyBGrBgAAmK1VZ4Sys7NVWlqqt956S126dLFeaxMbG6uoqCh9/vnnKi0t1V133aXu3bvro48+0vTp03X77bdryJAhkqTMzEwNGDBAEydOVFFRkZxOp+bMmaPs7GzrbMwTTzyhl156SU8//bQeffRRbdq0SWvXrtWGDRusXvLy8jRp0iQNHz5cI0aM0JIlS1RbW2tdRRYbG6spU6YoLy9PcXFxiomJ0ZNPPqn09HSuGAMAAJJaGYSWL18u6R9vmni+lStX6pFHHlFERITee+89K5QkJydr/PjxmjNnjlUbFham9evXa9q0aUpPT1fnzp01adIkzZ8/36pJSUnRhg0bNH36dC1dulS9evXSq6++al06L0n333+/vvnmG+Xn58vpdGrYsGEqKyvzewH14sWLFRoaqvHjx8vtdstut+vll19u1QIBAIDg1aog5PM1f5lpcnKytmzZ0uJx+vTp0+KVDyNHjtTevXubrcnJyVFOTs5F90dGRqq4uFjFxcUt9gQAAMzD7xoDAADGIggBAABjEYQAAICxCEIAAMBYBCEAAGAsghAAADAWQQgAABiLIAQAAIxFEAIAAMYiCAEAAGMRhAAAgLEIQgAAwFgEIQAAYCyCEAAAMBZBCAAAGIsgBAAAjEUQAgAAxiIIAQAAYxGEAACAsQhCAADAWAQhAABgLIIQAAAwFkEIAAAYiyAEAACMRRACAADGIggBAABjEYQAAICxCEIAAMBYBCEAAGAsghAAADAWQQgAABiLIAQAAIxFEAIAAMYiCAEAAGMRhAAAgLEIQgAAwFgEIQAAYCyCEAAAMBZBCAAAGIsgBAAAjEUQAgAAxiIIAQAAYxGEAACAsQhCAADAWAQhAABgLIIQAAAwFkEIAAAYiyAEAACMRRACAADGIggBAABjEYQAAICxWhWECgsLdeONN6pLly6Kj4/XuHHjdOjQIb+as2fPKjs7W927d9dVV12l8ePHq6amxq/myJEjysrKUnR0tOLj4zVjxgydO3fOr2bz5s264YYbZLPZdM0112jVqlWN+ikuLlbfvn0VGRmptLQ07dq1q9W9AAAAc7UqCG3ZskXZ2dnasWOHHA6HPB6PMjMzVVtba9VMnz5d77zzjt544w1t2bJFR48e1b333mvtr6+vV1ZWlurq6rR9+3a9/vrrWrVqlfLz862aw4cPKysrS3fccYeqq6uVm5urxx57TBs3brRq1qxZo7y8PD377LPas2ePhg4dKrvdrmPHjl1yLwAAwGydWlNcVlbm
d3vVqlWKj49XVVWVbr/9dp06dUqvvfaaSktLdeedd0qSVq5cqf79+2vHjh266aabVF5ero8//ljvvfeeEhISNGzYMC1YsEAzZ85UQUGBIiIiVFJSopSUFC1cuFCS1L9/f23btk2LFy+W3W6XJC1atEhTp07V5MmTJUklJSXasGGDVqxYoVmzZl1SLwAAwGytCkIXOnXqlCQpLi5OklRVVSWPx6OMjAyrpl+/furdu7cqKyt10003qbKyUoMHD1ZCQoJVY7fbNW3aNB04cEDXX3+9Kisr/Y7RUJObmytJqqurU1VVlWbPnm3tDw0NVUZGhiorKy+5lwu53W653W7rtsvlkiR5PB55PJ4ftEaXq4Z5gm2uptjCfE1vD/X5/Wka5r8y52/L71mTHgeaYvr8UvCuQWvm+cFByOv1Kjc3V7fccosGDRokSXI6nYqIiFDXrl39ahMSEuR0Oq2a80NQw/6Gfc3VuFwuff/99zpx4oTq6+ubrDl48OAl93KhwsJCzZs3r9H28vJyRUdHX2wprmgOhyPQLbS7ohHN718w3NsxjVymmP/Kmv/dd99t82Oa8DjQHNPnl4JvDc6cOXPJtT84CGVnZ2v//v3atm3bDz3EZWf27NnKy8uzbrtcLiUnJyszM1MxMTEB7KzteTweORwOjR49WuHh4YFup10NKtjY5HZbqE8Lhns1d3eo3N6QDu4q8Jj/ypx/f4G9zY5l0uNAU0yfXwreNWh4RudS/KAglJOTo/Xr12vr1q3q1auXtT0xMVF1dXU6efKk35mYmpoaJSYmWjUXXt3VcCXX+TUXXt1VU1OjmJgYRUVFKSwsTGFhYU3WnH+Mlnq5kM1mk81ma7Q9PDw8qL5AzhfMszVw1zf/n5zbG9JiTTBj/itr/vb4fjXhcaA5ps8vBd8atGaWVl015vP5lJOTozfffFObNm1SSkqK3/7U1FSFh4eroqLC2nbo0CEdOXJE6enpkqT09HTt27fP7+ouh8OhmJgYDRgwwKo5/xgNNQ3HiIiIUGpqql+N1+tVRUWFVXMpvQAAALO16oxQdna2SktL9dZbb6lLly7Wa21iY2MVFRWl2NhYTZkyRXl5eYqLi1NMTIyefPJJpaenWy9OzszM1IABAzRx4kQVFRXJ6XRqzpw5ys7Ots7GPPHEE3rppZf09NNP69FHH9WmTZu0du1abdiwweolLy9PkyZN0vDhwzVixAgtWbJEtbW11lVkl9ILAAAwW6uC0PLlyyVJI0eO9Nu+cuVKPfLII5KkxYsXKzQ0VOPHj5fb7ZbdbtfLL79s1YaFhWn9+vWaNm2a0tPT1blzZ02aNEnz58+3alJSUrRhwwZNnz5dS5cuVa9evfTqq69al85L0v33369vvvlG+fn5cjqdGjZsmMrKyvxeQN1SLwAAwGytCkI+X8uXmUZGRqq4uFjFxcUXrenTp0+LVz6MHDlSe/fubbYmJydHOTk5/1QvAADAXPyuMQAAYCyCEAAAMBZBCAAAGIsgBAAAjEUQAgAAxiIIAQAAYxGEAACAsQhCAADAWAQhAABgLIIQAAAwFkEIAAAYiyAEAACMRRACAADGIggBAABjEYQAAICxCEIAAMBYBCEAAGAsghAAADAWQQgAABiLIAQAAIxFEAIAAMYiCAEAAGMRhAAAgLEIQgAAwFgEIQAAYCyCEAAAMBZBCAAAGIsgBAAAjEUQAgAAxiIIAQAAYxGEAACAsQhCAADAWAQhAABgLIIQAAAwFkEIAAAYiyAEAACMRRACAADGIggBAABjEYQAAICxCEIAAMBYBCEAAGAsghAAADAWQQgAABiLIAQAAIxFEAIAAMYiCAEAAGMRhAAAgLEIQgAAwFgEIQAAYCyCEAAAMBZBCAAAGKvVQWjr1q26++67lZSUpJCQEK1bt85v/yOPPKKQkBC/jzFjxvjVHD9+XBMmTFBMTIy6du2qKVOm6PTp0341H330kW677TZFRkYqOTlZRUVFjXp544031K9f
P0VGRmrw4MF69913/fb7fD7l5+erZ8+eioqKUkZGhj799NPWjgwAAIJUq4NQbW2thg4dquLi4ovWjBkzRl9//bX18Yc//MFv/4QJE3TgwAE5HA6tX79eW7du1eOPP27td7lcyszMVJ8+fVRVVaUXXnhBBQUFeuWVV6ya7du368EHH9SUKVO0d+9ejRs3TuPGjdP+/futmqKiIi1btkwlJSXauXOnOnfuLLvdrrNnz7Z2bAAAEIQ6tfYTxo4dq7FjxzZbY7PZlJiY2OS+Tz75RGVlZfrwww81fPhwSdKLL76ou+66S7/5zW+UlJSk1atXq66uTitWrFBERIQGDhyo6upqLVq0yApMS5cu1ZgxYzRjxgxJ0oIFC+RwOPTSSy+ppKREPp9PS5Ys0Zw5c3TPPfdIkn73u98pISFB69at0wMPPNDa0QEAQJBpdRC6FJs3b1Z8fLy6deumO++8U88995y6d+8uSaqsrFTXrl2tECRJGRkZCg0N1c6dO/WTn/xElZWVuv322xUREWHV2O12Pf/88zpx4oS6deumyspK5eXl+d2v3W63nqo7fPiwnE6nMjIyrP2xsbFKS0tTZWVlk0HI7XbL7XZbt10ulyTJ4/HI4/H88wtzGWmYJ9jmaootzNf09lCf35+mYf4rc/62/J416XGgKabPLwXvGrRmnjYPQmPGjNG9996rlJQUff7553rmmWc0duxYVVZWKiwsTE6nU/Hx8f5NdOqkuLg4OZ1OSZLT6VRKSopfTUJCgrWvW7ducjqd1rbza84/xvmf11TNhQoLCzVv3rxG28vLyxUdHX2pS3BFcTgcgW6h3RWNaH7/guHejmnkMsX8V9b8F74Wsi2Y8DjQHNPnl4JvDc6cOXPJtW0ehM4/0zJ48GANGTJEV199tTZv3qxRo0a19d21qdmzZ/udZXK5XEpOTlZmZqZiYmIC2Fnb83g8cjgcGj16tMLDwwPdTrsaVLCxye22UJ8WDPdq7u5Qub0hHdxV4DH/lTn//gJ7mx3LpMeBppg+vxS8a9DwjM6laJenxs734x//WD169NBnn32mUaNGKTExUceOHfOrOXfunI4fP269rigxMVE1NTV+NQ23W6o5f3/Dtp49e/rVDBs2rMlebTabbDZbo+3h4eFB9QVyvmCerYG7vvn/5NzekBZrghnzX1nzt8f3qwmPA80xfX4p+NagNbO0+/sIffXVV/r222+tMJKenq6TJ0+qqqrKqtm0aZO8Xq/S0tKsmq1bt/o9x+dwOHTdddepW7duVk1FRYXffTkcDqWnp0uSUlJSlJiY6Ffjcrm0c+dOqwYAAJit1UHo9OnTqq6uVnV1taR/vCi5urpaR44c0enTpzVjxgzt2LFDf//731VRUaF77rlH11xzjez2f5zO7d+/v8aMGaOpU6dq165d+vOf/6ycnBw98MADSkpKkiQ99NBDioiI0JQpU3TgwAGtWbNGS5cu9Xva6qmnnlJZWZkWLlyogwcPqqCgQLt371ZOTo4kKSQkRLm5uXruuef09ttva9++fXr44YeVlJSkcePG/ZPLBgAAgkGrnxrbvXu37rjjDut2QziZNGmSli9fro8++kivv/66Tp48qaSkJGVmZmrBggV+TzmtXr1aOTk5GjVqlEJDQzV+/HgtW7bM2h8bG6vy8nJlZ2crNTVVPXr0UH5+vt97Dd18880qLS3VnDlz9Mwzz+jaa6/VunXrNGjQIKvm6aefVm1trR5//HGdPHlSt956q8rKyhQZGdnasQEAQBBqdRAaOXKkfL6LX266cWPTL0w9X1xcnEpLS5utGTJkiD744INma+677z7dd999F90fEhKi+fPna/78+S32BAAAzMPvGgMAAMYiCAEAAGMRhAAAgLEIQgAAwFgEIQAAYCyCEAAAMBZBCAAAGIsgBAAAjEUQAgAAxiIIAQAAYxGEAACAsQhCAADAWAQhAABgLIIQAAAwFkEIAAAYiyAEAACMRRACAADGIggBAABj
EYQAAICxCEIAAMBYBCEAAGAsghAAADAWQQgAABiLIAQAAIxFEAIAAMYiCAEAAGMRhAAAgLEIQgAAwFgEIQAAYCyCEAAAMBZBCAAAGIsgBAAAjEUQAgAAxiIIAQAAYxGEAACAsQhCAADAWAQhAABgLIIQAAAwFkEIAAAYiyAEAACMRRACAADGIggBAABjEYQAAICxCEIAAMBYBCEAAGAsghAAADAWQQgAABiLIAQAAIxFEAIAAMYiCAEAAGO1Oght3bpVd999t5KSkhQSEqJ169b57ff5fMrPz1fPnj0VFRWljIwMffrpp341x48f14QJExQTE6OuXbtqypQpOn36tF/NRx99pNtuu02RkZFKTk5WUVFRo17eeOMN9evXT5GRkRo8eLDefffdVvcCAADM1eogVFtbq6FDh6q4uLjJ/UVFRVq2bJlKSkq0c+dOde7cWXa7XWfPnrVqJkyYoAMHDsjhcGj9+vXaunWrHn/8cWu/y+VSZmam+vTpo6qqKr3wwgsqKCjQK6+8YtVs375dDz74oKZMmaK9e/dq3LhxGjdunPbv39+qXgAAgLk6tfYTxo4dq7Fjxza5z+fzacmSJZozZ47uueceSdLvfvc7JSQkaN26dXrggQf0ySefqKysTB9++KGGDx8uSXrxxRd111136Te/+Y2SkpK0evVq1dXVacWKFYqIiNDAgQNVXV2tRYsWWYFp6dKlGjNmjGbMmCFJWrBggRwOh1566SWVlJRcUi8AAMBsbfoaocOHD8vpdCojI8PaFhsbq7S0NFVWVkqSKisr1bVrVysESVJGRoZCQ0O1c+dOq+b2229XRESEVWO323Xo0CGdOHHCqjn/fhpqGu7nUnoBAABma/UZoeY4nU5JUkJCgt/2hIQEa5/T6VR8fLx/E506KS4uzq8mJSWl0TEa9nXr1k1Op7PF+2mplwu53W653W7rtsvlkiR5PB55PJ7mRr/iNMwTbHM1xRbma3p7qM/vT9Mw/5U5f1t+z5r0ONAU0+eXgncNWjNPmwahK11hYaHmzZvXaHt5ebmio6MD0FH7czgcgW6h3RWNaH7/guHejmnkMsX8V9b8F14U0hZMeBxojunzS8G3BmfOnLnk2jYNQomJiZKkmpoa9ezZ09peU1OjYcOGWTXHjh3z+7xz587p+PHj1ucnJiaqpqbGr6bhdks15+9vqZcLzZ49W3l5edZtl8ul5ORkZWZmKiYmpuUFuIJ4PB45HA6NHj1a4eHhgW6nXQ0q2NjkdluoTwuGezV3d6jc3pAO7irwmP/KnH9/gb3NjmXS40BTTJ9fCt41aHhG51K0aRBKSUlRYmKiKioqrLDhcrm0c+dOTZs2TZKUnp6ukydPqqqqSqmpqZKkTZs2yev1Ki0tzar5xS9+IY/HY/3DOBwOXXfdderWrZtVU1FRodzcXOv+HQ6H0tPTL7mXC9lsNtlstkbbw8PDg+oL5HzBPFsDd33z/8m5vSEt1gQz5r+y5m+P71cTHgeaY/r8UvCtQWtmafWLpU+fPq3q6mpVV1dL+seLkqurq3XkyBGFhIQoNzdXzz33nN5++23t27dPDz/8sJKSkjRu3DhJUv/+/TVmzBhNnTpVu3bt0p///Gfl5OTogQceUFJSkiTpoYceUkREhKZMmaIDBw5ozZo1Wrp0qd/ZmqeeekplZWVauHChDh48qIKCAu3evVs5OTmSdEm9AAAAs7X6jNDu3bt1xx13WLcbwsmkSZO0atUqPf3006qtrdXjjz+ukydP6tZbb1VZWZkiIyOtz1m9erVycnI0atQohYaGavz48Vq2bJm1PzY2VuXl5crOzlZqaqp69Oih/Px8v/cauvnmm1VaWqo5c+bomWee0bXXXqt169Zp0KBBVs2l9AIAAMzV6iA0cuRI+XwXv8oiJCRE8+fP1/z58y9aExcXp9LS0mbvZ8iQIfrggw+arbnvvvt03333/VO9AAAAc/G7xgAAgLEIQgAAwFgEIQAAYCyCEAAA
MBZBCAAAGIsgBAAAjEUQAgAAxiIIAQAAYxGEAACAsQhCAADAWAQhAABgLIIQAAAwFkEIAAAYiyAEAACMRRACAADGIggBAABjEYQAAICxCEIAAMBYBCEAAGAsghAAADAWQQgAABiLIAQAAIxFEAIAAMYiCAEAAGMRhAAAgLEIQgAAwFgEIQAAYCyCEAAAMBZBCAAAGIsgBAAAjEUQAgAAxiIIAQAAYxGEAACAsQhCAADAWAQhAABgLIIQAAAwFkEIAAAYiyAEAACMRRACAADGIggBAABjEYQAAICxCEIAAMBYBCEAAGAsghAAADAWQQgAABiLIAQAAIxFEAIAAMYiCAEAAGMRhAAAgLEIQgAAwFgEIQAAYKw2D0IFBQUKCQnx++jXr5+1/+zZs8rOzlb37t111VVXafz48aqpqfE7xpEjR5SVlaXo6GjFx8drxowZOnfunF/N5s2bdcMNN8hms+maa67RqlWrGvVSXFysvn37KjIyUmlpadq1a1dbjwsAAK5g7XJGaODAgfr666+tj23btln7pk+frnfeeUdvvPGGtmzZoqNHj+ree++19tfX1ysrK0t1dXXavn27Xn/9da1atUr5+flWzeHDh5WVlaU77rhD1dXVys3N1WOPPaaNGzdaNWvWrFFeXp6effZZ7dmzR0OHDpXdbtexY8faY2QAAHAFapcg1KlTJyUmJlofPXr0kCSdOnVKr732mhYtWqQ777xTqampWrlypbZv364dO3ZIksrLy/Xxxx/r97//vYYNG6axY8dqwYIFKi4uVl1dnSSppKREKSkpWrhwofr376+cnBz99Kc/1eLFi60eFi1apKlTp2ry5MkaMGCASkpKFB0drRUrVrTHyAAA4ArUqT0O+umnnyopKUmRkZFKT09XYWGhevfuraqqKnk8HmVkZFi1/fr1U+/evVVZWambbrpJlZWVGjx4sBISEqwau92uadOm6cCBA7r++utVWVnpd4yGmtzcXElSXV2dqqqqNHv2bGt/aGioMjIyVFlZedG+3W633G63ddvlckmSPB6PPB7PP7Uml5uGeYJtrqbYwnxNbw/1+f1pGua/Mue/7hfr2+xYtlCfFgyXUueXye0NabPjNmV/gb1dj/9DmPQ4eDHBugatmafNg1BaWppWrVql6667Tl9//bXmzZun2267Tfv375fT6VRERIS6du3q9zkJCQlyOp2SJKfT6ReCGvY37GuuxuVy6fvvv9eJEydUX1/fZM3Bgwcv2nthYaHmzZvXaHt5ebmio6MvbQGuMA6HI9AttLuiEc3vXzDc2zGNXKaY3+z5pY5Zg3fffbfd7+OHMuFxsCXBtgZnzpy55No2D0Jjx461/j5kyBClpaWpT58+Wrt2raKiotr67trU7NmzlZeXZ912uVxKTk5WZmamYmJiAthZ2/N4PHI4HBo9erTCw8MD3U67GlSwscnt//hp2Ku5u0Pb/afhyxHzmz2/1LFrcLmeETLlcfBignUNGp7RuRTt8tTY+bp27ap//dd/1WeffabRo0errq5OJ0+e9DsrVFNTo8TERElSYmJio6u7Gq4qO7/mwivNampqFBMTo6ioKIWFhSksLKzJmoZjNMVms8lmszXaHh4eHlRfIOcL5tkauOubf4B3e0NarAlmzG/2/FLHrMHl/DhjwuNgS4JtDVozS7u/j9Dp06f1+eefq2fPnkpNTVV4eLgqKiqs/YcOHdKRI0eUnp4uSUpPT9e+ffv8ru5yOByKiYnRgAEDrJrzj9FQ03CMiIgIpaam+tV4vV5VVFRYNQAAAG0ehP7rv/5LW7Zs0d///ndt375dP/nJTxQWFqYHH3xQsbGxmjJlivLy8vT++++rqqpKkydPVnp6um666SZJUmZmpgYMGKCJEyfqL3/5izZu3Kg5c+YoOzvbOlvzxBNP6G9/+5uefvppHTx4UC+//LLWrl2r6dOnW33k5eXpt7/9rV5//XV98sknmjZtmmprazV58uS2HhkAAFyh
2vypsa+++koPPvigvv32W/3oRz/Srbfeqh07duhHP/qRJGnx4sUKDQ3V+PHj5Xa7Zbfb9fLLL1ufHxYWpvXr12vatGlKT09X586dNWnSJM2fP9+qSUlJ0YYNGzR9+nQtXbpUvXr10quvviq7/f8/B33//ffrm2++UX5+vpxOp4YNG6aysrJGL6AGAADmavMg9Mc//rHZ/ZGRkSouLlZxcfFFa/r06dPiFQYjR47U3r17m63JyclRTk5OszUAAMBc/K4xAABgLIIQAAAwVrtfPo/g0nfWhkC3AABAm+GMEAAAMBZBCAAAGIsgBAAAjEUQAgAAxiIIAQAAYxGEAACAsQhCAADAWAQhAABgLIIQAAAwFkEIAAAYiyAEAACMRRACAADGIggBAABjEYQAAICxCEIAAMBYBCEAAGAsghAAADAWQQgAABiLIAQAAIxFEAIAAMYiCAEAAGMRhAAAgLEIQgAAwFidAt0AACD49Z21IdAtNGIL86lohDSoYKPc9SGN9v/911kB6AodjTNCAADAWJwRCqBA/oTU0k9CAACYgDNCAADAWAQhAABgLIIQAAAwFkEIAAAYiyAEAACMRRACAADGIggBAABjEYQAAICxCEIAAMBYBCEAAGAsghAAADAWQQgAABiLIAQAAIxFEAIAAMYiCAEAAGMRhAAAgLEIQgAAwFgEIQAAYCyCEAAAMBZBCAAAGIsgBAAAjEUQAgAAxjIiCBUXF6tv376KjIxUWlqadu3aFeiWAADAZSDog9CaNWuUl5enZ599Vnv27NHQoUNlt9t17NixQLcGAAACLOiD0KJFizR16lRNnjxZAwYMUElJiaKjo7VixYpAtwYAAAKsU6AbaE91dXWqqqrS7NmzrW2hoaHKyMhQZWVlo3q32y23223dPnXqlCTp+PHj8ng8bd5fp3O1bX7MS75vr09nznjVyROqem9IwPoIJNPXgPnNnl9iDVqa/5r/WhuArv45O2ePalW9x+PRmTNn9O233yo8PLyduup43333nSTJ5/O1WBvUQej//u//VF9fr4SEBL/tCQkJOnjwYKP6wsJCzZs3r9H2lJSUdusxkB4KdAOXAdPXgPlh+hoE2/w9Fga6g8vLd999p9jY2GZrgjoItdbs2bOVl5dn3fZ6vTp+/Li6d++ukJDg+mnJ5XIpOTlZX375pWJiYgLdTkCYvgbMb/b8Emtg+vxS8K6Bz+fTd999p6SkpBZrgzoI9ejRQ2FhYaqpqfHbXlNTo8TExEb1NptNNpvNb1vXrl3bs8WAi4mJCaov/h/C9DVgfrPnl1gD0+eXgnMNWjoT1CCoXywdERGh1NRUVVRUWNu8Xq8qKiqUnp4ewM4AAMDlIKjPCElSXl6eJk2apOHDh2vEiBFasmSJamtrNXny5EC3BgAAAizog9D999+vb775Rvn5+XI6nRo2bJjKysoavYDaNDabTc8++2yjpwJNYvoaML/Z80usgenzS6yBJIX4LuXaMgAAgCAU1K8RAgAAaA5BCAAAGIsgBAAAjEUQAgAAxiIIGaawsFA33nijunTpovj4eI0bN06HDh0KdFsB8+tf/1ohISHKzc0NdCsd5n//93/1H//xH+revbuioqI0ePBg7d69O9BtdZj6+nrNnTtXKSkpioqK0tVXX60FCxZc0u8kulJt3bpVd999t5KSkhQSEqJ169b57ff5fMrPz1fPnj0VFRWljIwMffrpp4Fpth00N7/H49HMmTM1ePBgde7cWUlJSXr44Yd19OjRwDXcxlr69z/fE088oZCQEC1ZsqTD+gs0gpBhtmzZouzsbO3YsUMOh0Mej0eZmZmqrQ3cL4ANlA8//FD//d//rSFDhgS6lQ5z4sQJ3XLLLQoPD9f//M//6OOPP9bChQvVrVu3QLfWYZ5//nktX75cL730kj755BM9//zzKioq0osvvhjo1tpNbW2thg4dquLi4ib3FxUVadmyZSopKdHOnTvVuXNn2e12nT17toM7bR/N
zX/mzBnt2bNHc+fO1Z49e/SnP/1Jhw4d0r//+78HoNP20dK/f4M333xTO3bsuKRfSxFUfDDasWPHfJJ8W7ZsCXQrHeq7777zXXvttT6Hw+H7t3/7N99TTz0V6JY6xMyZM3233nproNsIqKysLN+jjz7qt+3ee+/1TZgwIUAddSxJvjfffNO67fV6fYmJib4XXnjB2nby5EmfzWbz/eEPfwhAh+3rwvmbsmvXLp8k3xdffNExTXWgi83/1Vdf+f7lX/7Ft3//fl+fPn18ixcv7vDeAoUzQoY7deqUJCkuLi7AnXSs7OxsZWVlKSMjI9CtdKi3335bw4cP13333af4+Hhdf/31+u1vfxvotjrUzTffrIqKCv31r3+VJP3lL3/Rtm3bNHbs2AB3FhiHDx+W0+n0+16IjY1VWlqaKisrA9hZ4Jw6dUohISFB/7smG3i9Xk2cOFEzZszQwIEDA91Ohwv6d5bGxXm9XuXm5uqWW27RoEGDAt1Oh/njH/+oPXv26MMPPwx0Kx3ub3/7m5YvX668vDw988wz+vDDD/Xzn/9cERERmjRpUqDb6xCzZs2Sy+VSv379FBYWpvr6ev3yl7/UhAkTAt1aQDidTklq9G77CQkJ1j6TnD17VjNnztSDDz4YdL+E9GKef/55derUST//+c8D3UpAEIQMlp2drf3792vbtm2BbqXDfPnll3rqqafkcDgUGRkZ6HY6nNfr1fDhw/WrX/1KknT99ddr//79KikpMSYIrV27VqtXr1ZpaakGDhyo6upq5ebmKikpyZg1QNM8Ho9+9rOfyefzafny5YFup0NUVVVp6dKl2rNnj0JCQgLdTkDw1JihcnJytH79er3//vvq1atXoNvpMFVVVTp27JhuuOEGderUSZ06ddKWLVu0bNkyderUSfX19YFusV317NlTAwYM8NvWv39/HTlyJEAddbwZM2Zo1qxZeuCBBzR48GBNnDhR06dPV2FhYaBbC4jExERJUk1Njd/2mpoaa58JGkLQF198IYfDYczZoA8++EDHjh1T7969rcfEL774Qv/5n/+pvn37Brq9DsEZIcP4fD49+eSTevPNN7V582alpKQEuqUONWrUKO3bt89v2+TJk9WvXz/NnDlTYWFhAeqsY9xyyy2N3i7hr3/9q/r06ROgjjremTNnFBrq/zNgWFiYvF5vgDoKrJSUFCUmJqqiokLDhg2TJLlcLu3cuVPTpk0LbHMdpCEEffrpp3r//ffVvXv3QLfUYSZOnNjotZJ2u10TJ07U5MmTA9RVxyIIGSY7O1ulpaV666231KVLF+s1ALGxsYqKigpwd+2vS5cujV4P1blzZ3Xv3t2I10lNnz5dN998s371q1/pZz/7mXbt2qVXXnlFr7zySqBb6zB33323fvnLX6p3794aOHCg9u7dq0WLFunRRx8NdGvt5vTp0/rss8+s24cPH1Z1dbXi4uLUu3dv5ebm6rnnntO1116rlJQUzZ07V0lJSRo3blzgmm5Dzc3fs2dP/fSnP9WePXu0fv161dfXW4+LcXFxioiICFTbbaalf/8Lg194eLgSExN13XXXdXSrgRHoy9bQsSQ1+bFy5cpAtxYwJl0+7/P5fO+8845v0KBBPpvN5uvXr5/vlVdeCXRLHcrlcvmeeuopX+/evX2RkZG+H//4x75f/OIXPrfbHejW2s3777/f5Pf9pEmTfD7fPy6hnzt3ri8hIcFns9l8o0aN8h06dCiwTbeh5uY/fPjwRR8X33///UC33iZa+ve/kGmXz4f4fEH8dqoAAADN4MXSAADAWAQhAABgLIIQAAAwFkEIAAAYiyAEAACMRRACAADGIggBAABjEYQAAICxCEIAAMBYBCEAAGAsghAAADAWQQgAABjr/wESCjxMkjUgKwAAAABJRU5ErkJggg==",
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "news['line_length'].hist()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "05188d98-d126-46f7-aa2b-f9d33483575d",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(4159, 2049)"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
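   ],
   "source": [
    "# fine-tune on a 0.5% random sample of the headlines to keep training short\n",
    "sampled_headlines = news.headline_text.sample(int(0.005 * news.shape[0]), random_state=42).tolist()\n",
    "X_train, X_test = train_test_split(sampled_headlines, test_size=0.33, random_state=42)\n",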
    "len(X_train), len(X_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "c606838c-86f1-4ca1-ab82-41dd14915a22",
   "metadata": {},
   "outputs": [],
   "source": [
    "with open('train_dataset.txt','w') as f:\n",
    "  for line in X_train:\n",
    "    f.write(line)\n",
    "    f.write(\"\\n\")\n",
    "\n",
    "with open('test_dataset.txt','w') as f:\n",
    "  for line in X_test:\n",
    "    f.write(line)\n",
    "    f.write(\"\\n\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "75e15f55-5921-4489-9a9d-2544626b6acf",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/raghavbali/.pyenv/versions/3.11.9/envs/deeplearning/lib/python3.11/site-packages/transformers/tokenization_utils_base.py:1601: FutureWarning: `clean_up_tokenization_spaces` was not set. It will be set to `True` by default. This behavior will be depracted in transformers v4.45, and will be then set to `False` by default. For more details check this issue: https://github.com/huggingface/transformers/issues/31884\n",
      "  warnings.warn(\n"
     ]
    }
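   ],
   "source": [
    "# GPT-2 ships without a padding token, so '<pad>' is registered as a new special token.\n",
    "# This grows the vocabulary by one; the model's embeddings must be resized to len(tokenizer).\n",
    "tokenizer = AutoTokenizer.from_pretrained(\"gpt2\", pad_token='<pad>')\n",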
    "\n",
    "train_path = './train_dataset.txt'\n",
    "test_path = './test_dataset.txt'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "cc598a4e-6ea4-41b3-8241-57ea59c21b59",
   "metadata": {},
   "outputs": [],
   "source": [
    "import datasets\n",
    "# train_dataset = datasets.load_dataset('csv',data_files=train_path)\n",
    "# test_dataset = datasets.load_dataset('csv',data_files=test_path)\n",
    "dataset = datasets.load_dataset('text',data_files={'train':train_path,'test':test_path})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "8d2dad66-49cd-4af9-9133-9d2add0d1bd0",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "a9abfcfb3c95495295c5d3d4e2567ca6",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Map (num_proc=8):   0%|          | 0/4159 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "2cd2f0cc753b4636b4f63be06aa2fd82",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Map (num_proc=8):   0%|          | 0/2049 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "def tokenize_function(examples):\n",
    "    # Drop empty lines and append the end-of-text token so the model learns where a headline ends\n",
    "    examples[\"text\"] = [line + '<|endoftext|>' for line in examples[\"text\"] if len(line) > 0 and not line.isspace()]\n",
    "    return tokenizer(\n",
    "        examples[\"text\"],\n",
    "        truncation=True,\n",
    "    )\n",
    "\n",
    "tokenized_train_dataset = dataset['train'].map(\n",
    "    tokenize_function,\n",
    "    batched=True,\n",
    "    num_proc=8,\n",
    "    remove_columns=[\"text\"],\n",
    ")\n",
    "\n",
    "tokenized_test_dataset = dataset['test'].map(\n",
    "    tokenize_function,\n",
    "    batched=True,\n",
    "    num_proc=8,\n",
    "    remove_columns=[\"text\"],\n",
    ")"
   ]
  },
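  {
   "cell_type": "markdown",
   "id": "sketch-tokenize-note",
   "metadata": {},
   "source": [
    "To see what `tokenize_function` hands to the model, the following sketch tokenizes a single headline with the end-of-text marker appended. It is an illustration under assumptions: it loads a fresh `gpt2` tokenizer under the hypothetical name `tok`, and the headline is taken from the dataset preview shown earlier."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "sketch-tokenize-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import AutoTokenizer\n",
    "\n",
    "# Assumption: a fresh 'gpt2' tokenizer under the illustrative name `tok`\n",
    "tok = AutoTokenizer.from_pretrained('gpt2')\n",
    "\n",
    "headline = 'air nz strike to affect australian travellers' + tok.eos_token\n",
    "ids = tok(headline, truncation=True)['input_ids']\n",
    "\n",
    "print(tok.convert_ids_to_tokens(ids))  # sub-word pieces of the headline\n",
    "print(ids[-1] == tok.eos_token_id)     # checks whether the last id is the end-of-text id"
   ]
  },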
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "42589349-ece5-4b79-9159-a81112e500df",
   "metadata": {},
   "outputs": [],
   "source": [
    "data_collator = DataCollatorForLanguageModeling(\n",
    "    tokenizer=tokenizer, mlm=False,\n",
    ")"
   ]
  },
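  {
   "cell_type": "markdown",
   "id": "sketch-collator-note",
   "metadata": {},
   "source": [
    "With `mlm=False`, the collator prepares batches for causal language modeling: it pads every sequence in a batch to the same length, copies the input ids into the labels, and sets the label at padding positions to `-100` so that the loss ignores them. The shift by one position for next-token prediction happens inside the model. The following pure-Python sketch mimics that behaviour; `pad_and_label` and the token ids are hypothetical, and unlike the real collator (which masks every occurrence of the padding id) it only masks the positions it pads itself."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "sketch-collator-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "def pad_and_label(batch, pad_id):\n",
    "    \"\"\"Pad token-id lists to a common length and build causal-LM labels.\"\"\"\n",
    "    max_len = max(len(ids) for ids in batch)\n",
    "    input_ids, labels = [], []\n",
    "    for ids in batch:\n",
    "        n_pad = max_len - len(ids)\n",
    "        input_ids.append(ids + [pad_id] * n_pad)\n",
    "        # -100 is the index that PyTorch's cross-entropy loss ignores by default\n",
    "        labels.append(ids + [-100] * n_pad)\n",
    "    return input_ids, labels\n",
    "\n",
    "# Hypothetical token ids, with 0 standing in for the padding id\n",
    "input_ids, labels = pad_and_label([[11, 12, 13], [21, 22]], pad_id=0)\n",
    "print(input_ids)  # [[11, 12, 13], [21, 22, 0]]\n",
    "print(labels)     # [[11, 12, 13], [21, 22, -100]]"
   ]
  },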
  {
   "cell_type": "markdown",
   "id": "035383de-b4fb-470e-8448-d53c62c60911",
   "metadata": {},
   "source": [
    "## Prepare Model for Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "0e401adf-db30-4675-a11b-bce1aa031457",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "93f49cd0cecb4a719f3830b993b55b44",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "VBox(children=(HTML(value='<center> <img\\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "from huggingface_hub import notebook_login\n",
    "notebook_login()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "d7d61175-df48-426c-8180-b4f445a88587",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Backend Accelerator Device=mps\n"
     ]
    }
   ],
   "source": [
    "if torch.cuda.is_available():\n",
    "    DEVICE = 'cuda'\n",
    "    Tensor = torch.cuda.FloatTensor\n",
    "    LongTensor = torch.cuda.LongTensor\n",
    "    DEVICE_ID = 0\n",
    "# MPS/Apple Silicon does not work as intended for this pipeline    \n",
    "elif torch.backends.mps.is_available():\n",
    "    DEVICE = 'mps'\n",
    "    Tensor = torch.FloatTensor\n",
    "    LongTensor = torch.LongTensor\n",
    "    DEVICE_ID = 0\n",
    "else:\n",
    "    DEVICE = 'cpu'\n",
    "    Tensor = torch.FloatTensor\n",
    "    LongTensor = torch.LongTensor\n",
    "    DEVICE_ID = -1\n",
    "print(f\"Backend Accelerator Device={DEVICE}\")"
   ]
  },
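  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "fa7d8612-00b3-4078-b6a0-83d2bc0f1404",
   "metadata": {},
   "outputs": [],
   "source": [
    "# MPS only: a fraction of 0.0 lifts the per-process memory cap (unlimited allocations)\n",
    "if torch.backends.mps.is_available():\n",
    "    torch.mps.set_per_process_memory_fraction(0.0)"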
   ]
  },
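  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "d75f6b17-f49a-4284-822c-19903227120c",
   "metadata": {},
   "outputs": [],
   "source": [
    "model = AutoModelForCausalLM.from_pretrained(\"openai-community/gpt2-medium\")\n",
    "# the tokenizer gained a '<pad>' token, so grow the embedding matrix to match its vocabulary\n",
    "model.resize_token_embeddings(len(tokenizer))\n",
    "model = model.to(DEVICE)"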
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "1ff3d217-f33f-41f5-93fb-b462b88036d3",
   "metadata": {
    "editable": true,
    "slideshow": {
     "slide_type": ""
    },
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/raghavbali/.pyenv/versions/3.11.9/envs/deeplearning/lib/python3.11/site-packages/transformers/training_args.py:2179: UserWarning: `use_mps_device` is deprecated and will be removed in version 5.0 of 🤗 Transformers. `mps` device will be used by default if available similar to the way `cuda` device is used.Therefore, no action from user is required. \n",
      "  warnings.warn(\n"
     ]
    }
   ],
   "source": [
    "training_args = TrainingArguments(\n",
    "    \"gpt2-finetuned-headliner\",  # output directory\n",
    "    overwrite_output_dir=True,  # overwrite the contents of the output directory\n",
    "    num_train_epochs=2,  # number of training epochs\n",
    "    per_device_train_batch_size=64,  # batch size for training\n",
    "    per_device_eval_batch_size=64,  # batch size for evaluation\n",
    "    eval_steps=4,  # update steps between evaluations; has no effect unless eval_strategy=\"steps\" is also set\n",
    "    save_steps=8,  # save a checkpoint every 8 update steps\n",
    "    warmup_steps=4,  # warmup steps for the learning rate scheduler\n",
    "    logging_steps=8,  # log the training loss every 8 update steps\n",
    "    push_to_hub=True,  # upload the model to the Hugging Face Hub (requires the login above)\n",
    "    use_mps_device=True,  # deprecated: MPS is selected automatically when available, so this can be dropped\n",
    "    # use_cpu=True,  # uncomment to force CPU training when no GPU is available\n",
    ")"
   ]
  },
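  {
   "cell_type": "markdown",
   "id": "sketch-step-count-note",
   "metadata": {},
   "source": [
    "As a rough sanity check on the step-based arguments, the number of optimizer updates can be worked out by hand. This sketch assumes a single device, no gradient accumulation, and that no training examples were dropped during tokenization; the variable names are illustrative. The total of 130 updates is the figure the training progress bar counts up to."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "sketch-step-count-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "\n",
    "train_examples = 4159  # size of the training split\n",
    "batch_size = 64        # per_device_train_batch_size\n",
    "epochs = 2             # num_train_epochs\n",
    "logging_steps = 8\n",
    "save_steps = 8\n",
    "\n",
    "# the last, partially filled batch still counts as one update\n",
    "steps_per_epoch = math.ceil(train_examples / batch_size)\n",
    "total_steps = steps_per_epoch * epochs\n",
    "\n",
    "print(steps_per_epoch, total_steps)  # 65 130\n",
    "print(total_steps // logging_steps)  # 16 logged loss values\n",
    "print(total_steps // save_steps)     # 16 checkpoints written"
   ]
  },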
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "74cb3f58-3f93-45a1-97e4-4204c3cb967a",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "True"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "torch.backends.mps.is_available()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "148ff44c-f82e-41de-a595-23aafef9ea93",
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer = Trainer(\n",
    "    model=model,\n",
    "    args=training_args,\n",
    "    data_collator=data_collator,\n",
    "    train_dataset=tokenized_train_dataset,\n",
    "    eval_dataset=tokenized_test_dataset,\n",
    "    #prediction_loss_only=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "6b22d1aa-0820-4096-8a3d-d80244d4a619",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "      \n",
       "      <progress value='130' max='130' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      [130/130 02:19, Epoch 2/2]\n",
       "    </div>\n",
       "    <table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       " <tr style=\"text-align: left;\">\n",
       "      <th>Step</th>\n",
       "      <th>Training Loss</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>8</td>\n",
       "      <td>6.501000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>16</td>\n",
       "      <td>5.618000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>24</td>\n",
       "      <td>5.407000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>32</td>\n",
       "      <td>5.382200</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>40</td>\n",
       "      <td>5.262300</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>48</td>\n",
       "      <td>5.105400</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>56</td>\n",
       "      <td>5.107700</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>64</td>\n",
       "      <td>5.169300</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>72</td>\n",
       "      <td>4.687800</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>80</td>\n",
       "      <td>4.680900</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>88</td>\n",
       "      <td>4.734500</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>96</td>\n",
       "      <td>4.685100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>104</td>\n",
       "      <td>4.556900</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>112</td>\n",
       "      <td>4.660200</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>120</td>\n",
       "      <td>4.634800</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>128</td>\n",
       "      <td>4.633200</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table><p>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "TrainOutput(global_step=130, training_loss=5.044873604407678, metrics={'train_runtime': 140.587, 'train_samples_per_second': 59.166, 'train_steps_per_second': 0.925, 'total_flos': 248723096358912.0, 'train_loss': 5.044873604407678, 'epoch': 2.0})"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "trainer.train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "c823eb84-895c-4745-bac0-054772cedab7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# trainer.save_model()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "f2e812b9-38a4-43cd-aac1-3f9cac8d41e8",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "e16c98a2e9cc49b3a9be5c82787a7ea5",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "model.safetensors:   0%|          | 0.00/1.42G [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "b15cec162f26471e965b95c33d7eb885",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "training_args.bin:   0%|          | 0.00/5.18k [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "8fc4eff92f8a426e88724bffdb2a0f92",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Upload 2 LFS files:   0%|          | 0/2 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "CommitInfo(commit_url='https://huggingface.co/raghavbali/gpt2-finetuned-headliner/commit/b248c852d95c92139ac25712f852032663658b8c', commit_message='End of training', commit_description='', oid='b248c852d95c92139ac25712f852032663658b8c', pr_url=None, pr_revision=None, pr_num=None)"
      ]
     },
     "execution_count": 25,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# trainer.push_to_hub()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0c382996-2764-45c4-bc80-0148a9162e02",
   "metadata": {},
   "source": [
    "## Let us Generate Some Headlines!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "d73670c5-a1f6-40e6-998c-56165b26654f",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/raghavbali/.pyenv/versions/3.11.9/envs/deeplearning/lib/python3.11/site-packages/transformers/tokenization_utils_base.py:1601: FutureWarning: `clean_up_tokenization_spaces` was not set. It will be set to `True` by default. This behavior will be depracted in transformers v4.45, and will be then set to `False` by default. For more details check this issue: https://github.com/huggingface/transformers/issues/31884\n",
      "  warnings.warn(\n",
      "Hardware accelerator e.g. GPU is available in the environment, but no `device` argument is passed to the `Pipeline` object. Model will be on CPU.\n"
     ]
    }
   ],
   "source": [
    "# load the fine-tuned model\n",
    "ft_gpt2_headliner = AutoModelForCausalLM.from_pretrained(\"./gpt2-finetuned-headliner\").to('cpu')\n",
    "\n",
    "# setup the generation pipeline\n",
    "headliner = pipeline('text-generation',\n",
    "                     model=ft_gpt2_headliner, \n",
    "                     tokenizer='gpt2',\n",
    "                     pad_token_id=0,\n",
    "                     eos_token_id=50256,\n",
    "                     config={\n",
    "                         'max_length':8,\n",
    "                     },\n",
    "                    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "id": "a2f94852-1ffa-4c03-a9a7-41435c6b02d8",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_headline(headliner_pipeline, seed_text=\"News\"):\n",
    "  return headliner_pipeline(seed_text)[0]['generated_text'].split('\\n')[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "id": "5912e2af-6576-49fd-9ad9-63b43da657cf",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'News Zealand man charged with assault after being hit by a van'"
      ]
     },
     "execution_count": 28,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "get_headline(headliner, seed_text=\"News Zealand man charged with assault\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a83bbd99-9803-4524-be91-df8a91f9c316",
   "metadata": {
    "editable": true,
    "slideshow": {
     "slide_type": ""
    },
    "tags": []
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
